-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][SMT] restore custom builder for forall/exists #135470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][SMT] restore custom builder for forall/exists #135470
Conversation
|
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesThis reverts commit 54e70ac. The necessary change was to explicitly Full diff: https://github.com/llvm/llvm-project/pull/135470.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
index af73955caee54..1872c00b74f1a 100644
--- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
@@ -448,6 +448,18 @@ class QuantifierOp<string mnemonic> : SMTOp<mnemonic, [
VariadicRegion<SizedRegion<1>>:$patterns);
let results = (outs BoolType:$result);
+ let builders = [
+ OpBuilder<(ins
+ "TypeRange":$boundVarTypes,
+ "function_ref<Value(OpBuilder &, Location, ValueRange)>":$bodyBuilder,
+ CArg<"std::optional<ArrayRef<StringRef>>", "std::nullopt">:$boundVarNames,
+ CArg<"function_ref<ValueRange(OpBuilder &, Location, ValueRange)>",
+ "{}">:$patternBuilder,
+ CArg<"uint32_t", "0">:$weight,
+ CArg<"bool", "false">:$noPattern)>
+ ];
+ let skipDefaultBuilders = true;
+
let assemblyFormat = [{
($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)?
attr-dict-with-keyword $body (`patterns` $patterns^)?
diff --git a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
index 604dd26da1982..8977a3abc125d 100644
--- a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
+++ b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
@@ -432,6 +432,16 @@ LogicalResult ForallOp::verifyRegions() {
return verifyQuantifierRegions(*this);
}
+void ForallOp::build(
+ OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
+ function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+ std::optional<ArrayRef<StringRef>> boundVarNames,
+ function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+ uint32_t weight, bool noPattern) {
+ buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
+ boundVarNames, patternBuilder, weight, noPattern);
+}
+
//===----------------------------------------------------------------------===//
// ExistsOp
//===----------------------------------------------------------------------===//
@@ -448,5 +458,15 @@ LogicalResult ExistsOp::verifyRegions() {
return verifyQuantifierRegions(*this);
}
+void ExistsOp::build(
+ OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
+ function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+ std::optional<ArrayRef<StringRef>> boundVarNames,
+ function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+ uint32_t weight, bool noPattern) {
+ buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
+ boundVarNames, patternBuilder, weight, noPattern);
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/SMT/IR/SMT.cpp.inc"
diff --git a/mlir/unittests/Dialect/SMT/CMakeLists.txt b/mlir/unittests/Dialect/SMT/CMakeLists.txt
index 86e16d6194ea9..a1331467febaa 100644
--- a/mlir/unittests/Dialect/SMT/CMakeLists.txt
+++ b/mlir/unittests/Dialect/SMT/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRSMTTests
AttributeTest.cpp
+ QuantifierTest.cpp
TypeTest.cpp
)
diff --git a/mlir/unittests/Dialect/SMT/QuantifierTest.cpp b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp
new file mode 100644
index 0000000000000..328dba75d8655
--- /dev/null
+++ b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp
@@ -0,0 +1,199 @@
+//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace smt;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Test custom builders of ExistsOp
+//===----------------------------------------------------------------------===//
+
+TEST(QuantifierTest, ExistsBuilderWithPattern) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ ExistsOp existsOp = builder.create<ExistsOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ std::nullopt,
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return boundVars;
+ },
+ /*weight=*/2);
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ existsOp.print(stream);
+
+ ASSERT_STREQ(
+ stream.str().str().c_str(),
+ "%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, "
+ "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
+ "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
+ "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
+
+ existsOp->destroy();
+}
+
+TEST(QuantifierTest, ExistsBuilderNoPattern) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ ExistsOp existsOp = builder.create<ExistsOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ existsOp.print(stream);
+
+ ASSERT_STREQ(stream.str().str().c_str(),
+ "%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
+ "!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
+ "smt.yield %0 : !smt.bool\n}\n");
+
+ existsOp->destroy();
+}
+
+TEST(QuantifierTest, ExistsBuilderDefault) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ ExistsOp existsOp = builder.create<ExistsOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ ArrayRef<StringRef>{"a", "b"});
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ existsOp.print(stream);
+
+ ASSERT_STREQ(stream.str().str().c_str(),
+ "%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, "
+ "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
+ "%0 : !smt.bool\n}\n");
+
+ existsOp->destroy();
+}
+
+//===----------------------------------------------------------------------===//
+// Test custom builders of ForallOp
+//===----------------------------------------------------------------------===//
+
+TEST(QuantifierTest, ForallBuilderWithPattern) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ ForallOp forallOp = builder.create<ForallOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ ArrayRef<StringRef>{"a", "b"},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return boundVars;
+ },
+ /*weight=*/2);
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ forallOp.print(stream);
+
+ ASSERT_STREQ(
+ stream.str().str().c_str(),
+ "%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, "
+ "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
+ "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
+ "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
+
+ forallOp->destroy();
+}
+
+TEST(QuantifierTest, ForallBuilderNoPattern) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ ForallOp forallOp = builder.create<ForallOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ forallOp.print(stream);
+
+ ASSERT_STREQ(stream.str().str().c_str(),
+ "%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
+ "!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
+ "smt.yield %0 : !smt.bool\n}\n");
+
+ forallOp->destroy();
+}
+
+TEST(QuantifierTest, ForallBuilderDefault) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ ForallOp forallOp = builder.create<ForallOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ std::nullopt);
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ forallOp.print(stream);
+
+ ASSERT_STREQ(stream.str().str().c_str(),
+ "%0 = smt.forall {\n^bb0(%arg0: !smt.bool, "
+ "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
+ "%0 : !smt.bool\n}\n");
+
+ forallOp->destroy();
+}
+
+} // namespace
|
…an memory leak" This reverts commit 54e70ac.
43a1b45 to
a07ccba
Compare
math-fehr
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing this!
I guess CIRCT doesn't have an ASAN CI so this was missed?
Ya I guess not - maybe we should add that? I'll put it on my ever-growing to-do list lol |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/52/builds/7530 Here is the relevant piece of the build log for the reference |
|
Buildbot error seems unrelated? |
|
yea clang-repl has been acting up for a while now 🤷 |
|
Thanks for fixing this! |
This reverts commit 54e70ac which itself fixed an asan leak from the original upstreaming commit. The leak was due to op allocations not being
freeed.The necessary change was to explicitly->destroy()the ops at the end of the tests. I believe this is because the rewriter used in the tests doesn't actually insert them into a module and so without an explicit->destroy()no bookkeeping process is able to take care of them.The necessary change was to use
OwningOpRefwhich callsop->erase()in its own destructor.